{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Federated Learning: Download Trained Model\n",
    "\n",
    "In the \"[Part 01 - Create Plan](Part%2001%20-%20Create%20Plan.ipynb)\" notebook we created the model, training plan, and averaging plan, and then hosted all of them in PyGrid.\n",
    "\n",
    "Imagine that this hosted FL model has been trained using one of the client libraries (SwiftSyft, KotlinSyft, syft.js) or the FL client from the \"[Part 02 - Execute Plan](Part%2002%20-%20Execute%20Plan.ipynb)\" notebook.\n",
    "\n",
    "In this notebook, we'll download the model checkpoints and test them against the MNIST dataset.\n",
    "\n",
    "_NOTE_: Technically, this evaluation is not rigorous because there is no train/test split -\n",
    " clients train on randomly chosen samples from the MNIST dataset.\n",
    " However, since clients train on only a very small portion of the samples,\n",
    " we can still get a sense of how well the model generalises to the rest of the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting up Sandbox...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import torch as th\n",
    "from torch import nn\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import syft as sy\n",
    "from syft.grid.clients.model_centric_fl_client import ModelCentricFLClient\n",
    "from syft.grid.exceptions import GridError\n",
    "\n",
    "sy.make_hook(globals())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Utility function that sets tensors as model weights (copied from Part 01 notebook):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def set_model_params(module, params_list, start_param_idx=0):\n",
    "    \"\"\" Set params list into model recursively\n",
    "    \"\"\"\n",
    "    param_idx = start_param_idx\n",
    "\n",
    "    for name, param in module._parameters.items():\n",
    "        module._parameters[name] = params_list[param_idx]\n",
    "        param_idx += 1\n",
    "\n",
    "    for name, child in module._modules.items():\n",
    "        if child is not None:\n",
    "            param_idx = set_model_params(child, params_list, param_idx)\n",
    "\n",
    "    return param_idx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The model, identical to the one defined in the Part 01 notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(784, 392)\n",
    "        self.fc2 = nn.Linear(392, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = nn.functional.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Load MNIST dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "\n",
    "# MNIST training split, downloaded into the local 'data' directory if missing;\n",
    "# ToTensor converts each image to a tensor\n",
    "mnist_train_set = datasets.MNIST(\n",
    "    root='data',\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transforms.ToTensor(),\n",
    ")\n",
    "\n",
    "# drop_last=True discards a final incomplete batch, so every batch yielded\n",
    "# by the loader holds exactly batch_size samples\n",
    "mnist_dataset = th.utils.data.DataLoader(\n",
    "    dataset=mnist_train_set,\n",
    "    batch_size=batch_size,\n",
    "    drop_last=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Create client and model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (fc1): Linear(in_features=784, out_features=392, bias=True)\n",
       "  (fc2): Linear(in_features=392, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Address of the PyGrid Node on which the FL model was hosted and trained\n",
    "gridAddress = \"127.0.0.1:5000\"\n",
    "\n",
    "# FL client pointed at that node\n",
    "client = ModelCentricFLClient(address=gridAddress, id=\"test\")\n",
    "\n",
    "# Fresh model instance, switched to evaluation mode\n",
    "model = Net()\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Define an evaluation helper function that checks the model's accuracy against the whole MNIST dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def evaluate_model(name, version, checkpoint):\n",
    "    \"\"\"Download a model checkpoint and measure its accuracy on MNIST.\n",
    "\n",
    "    The checkpoint's tensors are fetched through the global ``client``,\n",
    "    loaded into the global ``model``, and the model is then run over every\n",
    "    batch produced by the global ``mnist_dataset`` loader.\n",
    "\n",
    "    Accuracy is computed as correct predictions over all samples seen, and\n",
    "    each batch is flattened using its own size, so the result is also valid\n",
    "    when the loader yields a smaller final batch.\n",
    "\n",
    "    Args:\n",
    "        name: Name of the hosted model.\n",
    "        version: Version of the hosted model.\n",
    "        checkpoint: Checkpoint identifier, passed through to ``client.get_model``.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of correctly classified samples as a float\n",
    "        (NaN if the loader yields no samples).\n",
    "\n",
    "    Errors raised by ``client.get_model`` are propagated unchanged.\n",
    "    \"\"\"\n",
    "    model_params_state = client.get_model(name, version, checkpoint)\n",
    "    model_params = model_params_state.tensors()\n",
    "\n",
    "    # Load model params into the model\n",
    "    set_model_params(model, model_params)\n",
    "\n",
    "    # Test\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with th.no_grad():\n",
    "        for X, y in mnist_dataset:\n",
    "            # Flatten each sample; use the real batch size rather than the\n",
    "            # configured one so a partial batch does not break the reshape\n",
    "            logits = model(X.view(X.shape[0], -1))\n",
    "            preds = th.argmax(logits, dim=1)\n",
    "            correct += preds.eq(y).sum().item()\n",
    "            total += y.shape[0]\n",
    "\n",
    "    return correct / total if total else float(\"nan\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Let's fetch all model checkpoints and see how the model improved from one to the next."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing checkpoint 1...Done (0.10757403948772679)\n",
      "Testing checkpoint 2...Done (0.5063367129135539)\n",
      "Testing checkpoint 3...Done (0.6688400480256137)\n",
      "Testing checkpoint 4...Done (0.7043256403415155)\n",
      "Testing checkpoint 5...Done (0.705909818569904)\n",
      "Testing checkpoint 6...Done (0.7358090981856991)\n",
      "Testing checkpoint 7...No more checkpoints to try\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<BarContainer object of 6 artists>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAQXUlEQVR4nO3df6zdd13H8eeLjiIIiNirwtrRBgpYBVm8DBOj/HCYzpHWxIFdojJlVowVBAS6YIapwUwgYkJqoMgYIYMy5q8ru9hEQEQE7AW6SdsUrqWwmxkpY0rQwCi8/eOckcPtOfd+b3vOvetnz0dysu/n+/3c73m/c7rX/Z7vOd/vTVUhSbrwPWitC5AkjYeBLkmNMNAlqREGuiQ1wkCXpEZctFZPvGHDhtq8efNaPb0kXZA+9alPfaWqpoZtW7NA37x5M3Nzc2v19JJ0QUryxVHbPOUiSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNWLMrRSVp3DbvvW2tS+jk1A1XTmS/HqFLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRnQK9CTbk5xIMp9k75Dtb0pypP/4XJL/Hn+pkqSlLPs99CTrgP3Ac4EF4HCSmao6dt+cqnrZwPzfAy6dQK2SxuiB/p3tFnW5sOgyYL6qTgIkOQjsBI6NmH818NrxlCfdfxiAur/rcsrlYuDOgfFCf91ZkjwO2AJ8aMT23UnmksydPn16pbVKkpbQJdAzZF2NmLsLuLWqvj1sY1UdqKrpqpqemhr6R6slSeeoS6AvAJsGxhuBu0bM3QW853yLkiStXJdAPwxsTbIlyXp6oT2zeFKSJwE/CHx8vCVKkrpYNtCr6gywBzgEHAduqaqjSfYl2TEw9WrgYFWNOh0jSZqgTrfPrapZYHbRuusXjf9ofGVJklbKK0UlqREGuiQ1wkCXpEb4J+g0EV5VKa0+j9AlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqRKdAT7I9yYkk80n2jpjzgiTHkhxN8u7xlilJWs6yf+AiyTpgP/BcYAE4nGSmqo4NzNkKXAf8TFXdk+SHJ1WwJGm4LkfolwHzVXWyqu4FDgI7F835LWB/Vd0DUFVfHm+ZkqTldAn0i4E7B8YL/XWDngg8McnHknwiyfZhO0qyO8lckrnTp0+fW8WSpKG6BHqGrKtF44uArcCzgKuBv0zyqLN+qOpAVU1X1fTU1NRKa5UkLaFLoC8AmwbGG4G7hsz5u6r6VlV9AThBL+AlSaukS6AfBrYm2ZJkPbALmFk052+BZwMk2UDvFMzJcRYqSVrasoFeVWeAPcAh4DhwS1UdTbIvyY7+tEPA3UmOAR8GXllVd0+qaEnS2Zb92iJAVc0Cs4vWXT+wXMDL+w9J0hrwSlFJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEZ0CPcn2JCeSzCfZO2T7NUlOJznSf1w7/lIlSUtZ9o9EJ1kH7AeeCywAh5PMVNWxRVPfW1V7JlDjA8LmvbetdQmdnLrhyrUuQdIIXY7QLwPmq+pkVd0LHAR2TrYsSdJKdQn0i4E7B8YL/XWL/XKSO5LcmmTTsB0l2Z1kLsnc6dOnz6FcSdIoXQI9Q9bVovHfA5ur6qnAPwLvHLajqjpQVdNVNT01NbWySiVJS+oS6AvA4BH3RuCuwQlVdXdVfbM/fBvwU+MpT5LUVZdAPwxsTbIlyXpgFzAzOCHJYwaGO4Dj4ytRktTFst9yqaozSfYAh4B1wI1VdTTJPmCuqmaAlyTZAZwBvgpcM8GaJUlDLBvoAFU1C8wuWnf9
wPJ1wHXjLU2StBJeKSpJjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1olOgJ9me5ESS+SR7l5h3VZJKMj2+EiVJXSwb6EnWAfuBK4BtwNVJtg2Z9wjgJcAnx12kJGl5XY7QLwPmq+pkVd0LHAR2Dpn3x8DrgW+MsT5JUkddAv1i4M6B8UJ/3XcluRTYVFXvX2pHSXYnmUsyd/r06RUXK0karUugZ8i6+u7G5EHAm4BXLLejqjpQVdNVNT01NdW9SknSsroE+gKwaWC8EbhrYPwI4CeAf0pyCvhpYMYPRiVpdXUJ9MPA1iRbkqwHdgEz922sqv+pqg1VtbmqNgOfAHZU1dxEKpYkDbVsoFfVGWAPcAg4DtxSVUeT7EuyY9IFSpK6uajLpKqaBWYXrbt+xNxnnX9ZkqSV8kpRSWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqRGdAj3J9iQnkswn2Ttk+4uT/HuSI0n+Jcm28ZcqSVrKsoGeZB2wH7gC2AZcPSSw311VT6mqpwGvB/5s7JVKkpbU5Qj9MmC+qk5W1b3AQWDn4ISq+trA8PuBGl+JkqQuLuow52LgzoHxAvCMxZOS/C7wcmA98JxhO0qyG9gNcMkll6y0VknSErocoWfIurOOwKtqf1U9Hng18IfDdlRVB6pquqqmp6amVlapJGlJXQJ9Adg0MN4I3LXE/IPAL51PUZKklesS6IeBrUm2JFkP7AJmBick2TowvBL4/PhKlCR1sew59Ko6k2QPcAhYB9xYVUeT7APmqmoG2JPkcuBbwD3ACydZtCTpbF0+FKWqZoHZReuuH1h+6ZjrkiStkFeKSlIjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY3oFOhJtic5kWQ+yd4h21+e5FiSO5J8MMnjxl+qJGkpywZ6knXAfuAKYBtwdZJti6Z9BpiuqqcCtwKvH3ehkqSldTlCvwyYr6qTVXUvcBDYOTihqj5cVf/XH34C2DjeMiVJy+kS6BcDdw6MF/rrRnkR8IHzKUqStHIXdZiTIetq6MTkV4Fp4Jkjtu8GdgNccsklHUuUJHXR5Qh9Adg0MN4I3LV4UpLLgdcAO6rqm8N2VFUHqmq6qqanpqbOpV5J0ghdAv0wsDXJliTrgV3AzOCEJJcCb6UX5l8ef5mSpOUsG+hVdQbYAxwCjgO3VNXRJPuS7OhPewPwcOB9SY4kmRmxO0nShHQ5h05VzQKzi9ZdP7B8+ZjrkiStkFeKSlIjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY3oFOhJtic5kWQ+yd4h238uyaeTnEly1fjLlCQtZ9lAT7IO2A9cAWwDrk6ybdG0LwHXAO8ed4GSpG4u6jDnMmC+qk4CJDkI7ASO3Tehqk71t31nAjWeZfPe21bjac7bqRuuXOsSJD2AdDnlcjFw58B4ob9OknQ/0iXQM2RdncuTJdmdZC7J3OnTp89lF5KkEboE+gKwaWC8EbjrXJ6sqg5U1XRVTU9NTZ3LLiRJI3QJ9MPA1iRbkqwHdgEzky1LkrRSywZ6VZ0B9gCHgOPALVV1NMm+JDsAkjw9yQLwfOCtSY5OsmhJ0tm6fMuFqpoFZhetu35g+TC9UzGSpDXilaKS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGg
S1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWpEp0BPsj3JiSTzSfYO2f6QJO/tb/9kks3jLlSStLRlAz3JOmA/cAWwDbg6ybZF014E3FNVTwDeBPzpuAuVJC2tyxH6ZcB8VZ2sqnuBg8DORXN2Au/sL98K/HySjK9MSdJyUlVLT0iuArZX1bX98a8Bz6iqPQNzPtufs9Af/0d/zlcW7Ws3sLs/fBJwYlyNjMEG4CvLzrqwtNZTa/1Aez211g/c/3p6XFVNDdtwUYcfHnakvfi3QJc5VNUB4ECH51x1Seaqanqt6xin1npqrR9or6fW+oELq6cup1wWgE0D443AXaPmJLkI+AHgq+MoUJLUTZdAPwxsTbIlyXpgFzCzaM4M8ML+8lXAh2q5czmSpLFa9pRLVZ1Jsgc4BKwDbqyqo0n2AXNVNQO8HXhXknl6R+a7Jln0hNwvTwWdp9Z6aq0faK+n1vqBC6inZT8UlSRdGLxSVJIaYaBLUiMM9CGSvCfJ5iS/n2TXwPo9/dsbVJINa1njSi3R08392zp8NsmNSR68lnV2tUQ/b09ye5I7ktya5OFrWedKjOppYPubk3x9LWo7F0u8Rjcl+UKSI/3H09ayzpVYoqckeV2SzyU5nuQla1GfgT7clqo6BTwT+OjA+o8BlwNfXIuiztOonm4Gngw8BXgocO3ql3ZORvXzsqr6yap6KvAlYM+wH76fGtUTSaaBR61FUedhZD/AK6vqaf3HkdUv7ZyN6ukael/dfnJV/Ri9K+pXXXOBnuTX+0dntyd5V3/dTUnekuSj/d+gzxvxszcnOQY8KckR4BeA25JcC1BVn+m/mKtqwj3NVh/wb/SuM7iQ+/laf17o/YJalU/9J9lTevdTegPwqtXopf+cE+tnrUy4p98B9lXVdwCq6sur0NLZqqqZB/Dj9G4nsKE/fnT/vzcB/0DvF9hWehdCfd+IfbwAeAWwGXjfiDmn7nuOhnp6MPBp4Gcv9H6AdwD/BXwYeNiF/hoBL6X3zgPg6w30c1N//3fQu5nfQxro6W7gNcAc8AFg66R7GvZo7Qj9OcCt1b+HTFUNXq16S1V9p6o+D5ykd5phmEuBI/ROQdwf3gquVk9/AfxzVS1+azxuE++nqn4DeCxwHPiVMdY+ysR6SvJY4PnAmydR+AiTfo2u6//c04FHA68eY+2jTLqnhwDfqN4tAt4G3DjO4rvqci+XC0kY/RZ78frvGSf5ReBPgC3A84Ap4H+TXF5Vzx53oSsw8Z6SvLa/7bfHVfQSVuU1qqpvJ3kv8Ep6R+yTNMmeLgWeAMz3ziLxsCTz1btV9aRM9DWqqv/sT/9mkncAfzCuwpcw6X93C8Bf9Zf/hsn/mxtuLd4WTPht1eeAHxrytmqW3tuqxzPibRW9K2E/1l/+IPDIEc9zitU95TKxnuh9CPqvwEMv9H7o/U/7hIHlNwJvvJB7GjJ3tU65TPLf3GMGXqM/B25ooKcbgN/sLz8LODzpnoY9mjpCr94tCV4HfCTJt4HP0Pv0GXrnzz4C/Ajw4qr6xpBdXArcnt49ax5c/Q/Y7tP/KtKrgB8F7kgyW/3bCk/KpHsC3kLvWzsf7x8B/nVV7Rt/Jz0T7ifAO5M8sr98O70PqyZqFV6jVbUK/dycZIrea3QEePEE2vgeq9DTDfT6ehnwddbo22IPiEv/k9wEvL+qbl3rWsaltZ5a6wfa66m1fqC9nlr7UFSSHrAeEEfokvRA4BG6JDXCQJekRhjoktQIA12SGmGgS1Ij/h+Hl8GWEPpiYQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "name = \"mnist\"\n",
    "version = \"1.0.0\"\n",
    "checkpoint = 1\n",
    "\n",
    "checkpoints = []\n",
    "accuracies = []\n",
    "\n",
    "# Starting from checkpoint 1, keep requesting the next checkpoint until the\n",
    "# grid raises an error for it\n",
    "while True:\n",
    "    print(f\"Testing checkpoint {checkpoint}...\", end=\"\")\n",
    "    try:\n",
    "        # Only the download/evaluation is guarded, so unrelated bugs below\n",
    "        # are never reported as a missing checkpoint\n",
    "        accuracy = evaluate_model(name, version, checkpoint)\n",
    "    except GridError as err:\n",
    "        # Usually the model/checkpoint was not found; the error text is shown\n",
    "        # so that any other grid failure is not mistaken for it\n",
    "        print(f\"No more checkpoints to try ({err})\")\n",
    "        break\n",
    "    print(f\"Done ({accuracy})\")\n",
    "    checkpoints.append(f\"cp #{checkpoint}\")\n",
    "    accuracies.append(accuracy)\n",
    "    checkpoint += 1\n",
    "\n",
    "# Labelled bar chart of accuracy per checkpoint\n",
    "fig, ax = plt.subplots()\n",
    "ax.bar(checkpoints, accuracies)\n",
    "ax.set(\n",
    "    title=f\"Accuracy of '{name}' v{version} by checkpoint\",\n",
    "    xlabel=\"Checkpoint\",\n",
    "    ylabel=\"Accuracy\",\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "NOTE: All checkpoints are retrieved here only to demonstrate how the model improves over time.\n",
    "To simply get the latest checkpoint, do `client.get_model(name, version)` or `client.get_model(name, version, \"latest\")`."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
